import os
import glob
import h5py
import argparse

from tqdm import tqdm, trange

def merge_hdf5(
    base_path,
    dirname
):
    """Concatenate per-part ``data.hdf5`` files into one ``data_train.hdf5``.

    Every file matching ``<base_path>/part*/<dirname>/data.hdf5`` is read, and
    each of its top-level keys is appended along axis 0 to a dataset of the
    same name in ``<base_path>/data_train.hdf5``. Any existing output file is
    overwritten.

    Args:
        base_path: Directory holding the ``part*`` sub-directories; the merged
            file is written directly inside it.
        dirname: Name of the run directory expected inside each ``part*``
            directory.

    Note:
        Assumes every top-level key is a dataset with at least one dimension
        and that its trailing shape and dtype agree across parts -- TODO
        confirm against the code that writes the part files.
    """
    # glob returns files in arbitrary order; sort so that the row order of the
    # merged file is reproducible from run to run.
    part_files = sorted(
        glob.glob(os.path.join(base_path, "part*", dirname, "data.hdf5"))
    )
    out_path = os.path.join(base_path, "data_train.hdf5")

    # Mode "w" creates the file or truncates an existing one, so no separate
    # remove step is needed. The context managers guarantee both files are
    # closed even if a read or write fails midway.
    with h5py.File(out_path, "w") as merged:
        for file_idx, h5name in (ppbar := tqdm(enumerate(part_files), total=len(part_files), ncols=0)):
            ppbar.set_postfix({"file_idx": file_idx})
            with h5py.File(h5name, "r") as part:
                for obj in (pbar := tqdm(part.keys(), ncols=0, leave=False)):
                    pbar.set_postfix({"col": obj})
                    source = part[obj]
                    num_rows = len(source)
                    # "ob" is copied one row at a time, everything else in
                    # batches of 1000 rows.
                    chunk_size = 1000 if obj != "ob" else 1
                    for start in trange(0, num_rows, chunk_size, leave=False):
                        stop = min(start + chunk_size, num_rows)
                        # A contiguous slice reads the same rows as an explicit
                        # index list, without the fancy-indexing overhead.
                        rows = source[start:stop]
                        if obj not in merged:
                            # Create on first sight of the key, whichever part
                            # it comes from; a key that is empty or absent in
                            # the first part is therefore still handled.
                            row_shape = rows.shape[1:]
                            merged.create_dataset(
                                obj,
                                compression="gzip",
                                chunks=(1, *row_shape),
                                # None leaves axis 0 unlimited so the dataset
                                # can keep growing as parts are appended.
                                maxshape=(None, *row_shape),
                                data=rows,
                            )
                        else:
                            target = merged[obj]
                            old_len = target.shape[0]
                            target.resize(old_len + len(rows), axis=0)
                            target[old_len:old_len + len(rows)] = rows
            
            
if __name__ == "__main__":
    # Script entry point: parse the rollout options, rebuild the run directory
    # name from them, and merge every part's data.hdf5 under --base_path.
    cli = argparse.ArgumentParser(description='Process rollout training arguments.')

    # Options are registered in this order. Only the ones that feed the
    # directory name, plus --base_path, are read below; the rest are accepted
    # on the command line but not used in this script.
    option_table = [
        ('--env_name', dict(type=str, default='coinrun')),
        ('--env_type', dict(type=str, default='none')),
        ('--num_levels', dict(type=int, default=500)),
        ('--start_level', dict(type=int, default=0)),
        ('--distribution_mode', dict(type=str, default='hard')),
        ('--image_keys', dict(type=str, default='ob')),
        ('--base_path', dict(type=str, default=None, required=True)),
        ('--num_demonstrations', dict(type=int, default=1000)),
        ('--save_type', dict(type=str, default='npy', choices=['npy', 'hdf5'])),
        ('--num_frames', dict(type=int, default=4)),
        ('--model_type', dict(
            type=str,
            default='clip',
            choices=['clip', 'ts2net', 'mugen', 'mugen_finetune', 'clip_finetune', 'clip_action_finetune'],
        )),
        ('--model_ckpt_dir', dict(type=str, default=None)),
    ]
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    args = cli.parse_args()

    # Underscore-joined pieces of the run directory name; the environment type
    # is appended as a suffix only when it is not the "none" placeholder.
    name_parts = [
        f"{args.env_name}",
        f"{args.distribution_mode}",
        f"level{args.start_level}to{args.num_levels}",
        f"num{args.num_demonstrations}",
        f"frame{args.num_frames}",
    ]
    if args.env_type != "none":
        name_parts.append(f"{args.env_type}")
    dirname = "_".join(name_parts)

    print(f"dirname: {dirname}")

    merge_hdf5(base_path=args.base_path, dirname=dirname)